Bayesian GAM Part 1

Author

Murray Logan

Published

06/07/2025

1 Preparations

Load the necessary libraries

library(tidyverse) # for data wrangling etc
library(cmdstanr) # for cmdstan
library(brms) # for fitting models in STAN
library(standist) # for exploring distributions
library(coda) # for diagnostics
library(bayesplot) # for diagnostics
library(ggmcmc) # for MCMC diagnostics
library(DHARMa) # for residual diagnostics
library(rstan) # for interfacing with STAN
library(emmeans) # for marginal means etc
library(broom) # for tidying outputs
library(tidybayes) # for more tidying outputs
library(HDInterval) # for HPD intervals
library(ggeffects) # for partial plots
library(broom.mixed) # for summarising models
library(posterior) # for posterior draws
library(patchwork) # for multi-panel figures
library(bayestestR) # for ROPE
library(see) # for some plots
library(easystats) # framework for stats, modelling and visualisation
library(mgcv) # for frequentist GAMs and basis construction
library(gratia) # for visualising GAM smooths
theme_set(theme_grey()) # put the default ggplot theme back
source("helperFunctions.R")

2 Scenario

This is an entirely fabricated example (how embarrassing). So here is a picture of some Red Grouse chicks to compensate.

Figure 1: Red grouse chicks
Table 1: Format of data_gam.csv data file
 x  y
 2  3
 4  5
 8  6
10  7
14  4
Table 2: Description of the variables in data_gam.csv
x - a continuous predictor
y - a continuous response

3 Read in the data

data_gam <- read_csv("../data/data_gam.csv", trim_ws = TRUE)
Rows: 5 Columns: 2
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
dbl (2): x, y

ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
data_gam |> glimpse()
Rows: 5
Columns: 2
$ x <dbl> 2, 4, 8, 10, 14
$ y <dbl> 3, 5, 6, 7, 4
## Explore the first 6 rows of the data
data_gam |> head()
# A tibble: 5 × 2
      x     y
  <dbl> <dbl>
1     2     3
2     4     5
3     8     6
4    10     7
5    14     4
data_gam |> str()
spc_tbl_ [5 × 2] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
 $ x: num [1:5] 2 4 8 10 14
 $ y: num [1:5] 3 5 6 7 4
 - attr(*, "spec")=
  .. cols(
  ..   x = col_double(),
  ..   y = col_double()
  .. )
 - attr(*, "problems")=<externalptr> 
data_gam |> datawizard::data_codebook()
data_gam (5 rows and 2 variables, 2 shown)

ID | Name | Type    | Missings | Values |         N
---+------+---------+----------+--------+----------
1  | x    | numeric | 0 (0.0%) |      2 | 1 (20.0%)
   |      |         |          |      4 | 1 (20.0%)
   |      |         |          |      8 | 1 (20.0%)
   |      |         |          |     10 | 1 (20.0%)
   |      |         |          |     14 | 1 (20.0%)
---+------+---------+----------+--------+----------
2  | y    | numeric | 0 (0.0%) |      3 | 1 (20.0%)
   |      |         |          |      4 | 1 (20.0%)
   |      |         |          |      5 | 1 (20.0%)
   |      |         |          |      6 | 1 (20.0%)
   |      |         |          |      7 | 1 (20.0%)
---------------------------------------------------

4 Exploratory data analysis

Model formula: \[ \begin{align} y_i &\sim{} \mathcal{N}(\mu_i, \sigma^2)\\ \mu_i &=\beta_0 + f(x_i)\\ f(x_i) &= \sum^k_{j=1}{b_j(x_i)\beta_j} \end{align} \]

where \(\beta_0\) is the y-intercept, and \(f(x)\) indicates an additive smoothing function of \(x\).
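The additive structure in the formula above can be illustrated directly before any model fitting: a smoother is just a weighted sum of basis functions. The following sketch evaluates \(f(x) = \sum_j b_j(x)\beta_j\) using mgcv's basis construction; the coefficient values are entirely made up for illustration.

```r
library(mgcv)

## The same five x values as the data below
d <- data.frame(x = c(2, 4, 8, 10, 14))
## Construct a k = 3 thin-plate spline basis and evaluate it at each x
X <- smoothCon(s(x, k = 3), data = d)[[1]]$X
## Hypothetical spline coefficients (beta_j), purely for illustration
beta <- c(0.5, -1, 2)
## f(x_i) = sum_j b_j(x_i) * beta_j is just a matrix product
f <- as.vector(X %*% beta)
f
```

Estimating a GAM amounts to estimating those beta values (plus a penalty on their wiggliness), which is exactly what brms will do below.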

Although this is a fictitious example without a clear backstory, given that there are two continuous variables (and that one of these has been identified as the response and the other as the predictor), we can assume that we might be interested in investigating the relationship between the two. As such, our typical starting point is to explore the basic trend between the two using a scatterplot.

ggplot(data_gam, aes(y = y, x = x)) +
  geom_point() +
  geom_line()

This does not look like a particularly linear relationship. Let's fit a loess smoother.

ggplot(data_gam, aes(y = y, x = x)) +
  geom_point() +
  geom_smooth()

And what would a linear smoother look like?

ggplot(data_gam, aes(y = y, x = x)) +
  geom_point() +
  geom_smooth(method = "lm")

Rather than either a loess or linear smoother, we can also try a Generalized Additive Model (GAM) smoother. Don't pay too much attention to the GAM formula at this stage; it will be discussed later in the Model Fitting section.

ggplot(data_gam, aes(y = y, x = x)) +
  geom_point() +
  geom_smooth(method = "gam", formula = y ~ s(x, k = 3))

Conclusions:

  • it is clear that the relationship is not linear.
  • it does appear that as x increases, y initially increases before eventually declining again.
  • we could model this with a polynomial, but for this example, we will use these data to illustrate the fitting of GAMs.

5 Fit the model

Prior to fitting the GAM, it might be worth gaining a bit of an understanding of what will occur behind the scenes.

Let's say we intended to fit a smoother with three knots. The three knots equate to one at each end of the trend and one in the middle. We could re-express our predictor (x) as three dummy variables that collectively reflect a spline (in this case, potentially two joined polynomials).

data.frame(smoothCon(s(x, k = 3), data = data_gam)[[1]]$X) %>%
  bind_cols(data_gam)
         X1 X2          X3  x y
1 1.3554342  1 -1.31122014  2 3
2 0.9289363  1 -0.84292723  4 5
3 0.4755086  1  0.09365858  8 6
4 0.5780165  1  0.56195149 10 7
5 1.3189632  1  1.49853730 14 4

And we could visualize these dummies as three separate components.

basis(s(x, k = 3), data = data_gam) %>% draw()

basis(s(x, k = 3, bs = "cr"), data = data_gam) %>% draw()

OR

newdata <-
  data.frame(smoothCon(s(x, k = 3), data = data_gam)[[1]]$X) %>%
  bind_cols(data_gam)
ggplot(newdata, aes(x = x)) +
  geom_line(aes(y = X1)) +
  geom_line(aes(y = X2)) +
  geom_line(aes(y = X3))

brms follows the same basic process as gamm4. That is, the smooths are partitioned into two components:

  • a penalised component which is treated as a random effect
  • an unpenalised component that is treated as a fixed effect

The wiggliness penalty matrix is the precision matrix when the smooth is treated as a random effect. The smoothness of a term is determined by estimating the variance of that term.
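This partitioning can be inspected directly with mgcv (a sketch of what gamm4 and brms do internally, not the exact code either package runs): smooth2random() re-parameterises a smooth so that the penalised part becomes a random-effect design matrix while the penalty-free (null-space) part becomes ordinary fixed-effect columns.

```r
library(mgcv)

d <- data.frame(x = c(2, 4, 8, 10, 14))
## Build the smooth with the identifiability constraint absorbed
sm <- smoothCon(s(x, k = 3), data = d, absorb.cons = TRUE)[[1]]
## The wiggliness penalty matrix (precision matrix of the "random effect")
S <- sm$S[[1]]
## Re-parameterise: penalised part -> random-effect design,
## penalty-free part -> fixed-effect columns
re <- smooth2random(sm, "", type = 2)
str(re$Xf)   # unpenalised (fixed-effect) columns
str(re$rand) # penalised (random-effect) design
```

With k = 3 and the constraint absorbed, the basis has two columns, so the penalty matrix is 2 x 2; one column ends up unpenalised (fixed) and one penalised (random).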

In brms, the default priors are designed to be weakly informative. They are chosen to provide moderate regularisation (to help prevent over-fitting) and help stabilise the computations.

Unlike rstanarm, brms models must be compiled before they start sampling. For most models, the compilation of the Stan code takes around 45 seconds.

data_gam.form <- bf(y ~ s(x), family = gaussian())
data_gam.brm <- brm(data_gam.form,
  data = data_gam,
  iter = 5000,
  warmup = 1000,
  chains = 3,
  thin = 5,
  refresh = 0,
  backend = "rstan"
)
Error in smooth.construct.tp.smooth.spec(object, dk$data, dk$knots): A term has fewer unique covariate combinations than specified maximum degrees of freedom
data_gam.form <- bf(y ~ s(x, k = 3), family = gaussian())
get_prior(data_gam.form, data = data_gam)
                prior     class        coef group resp dpar nlpar lb ub
               (flat)         b                                        
               (flat)         b        sx_1                            
 student_t(3, 5, 2.5) Intercept                                        
 student_t(3, 0, 2.5)       sds                                    0   
 student_t(3, 0, 2.5)       sds s(x, k = 3)                        0   
 student_t(3, 0, 2.5)     sigma                                    0   
       source
      default
 (vectorized)
      default
      default
 (vectorized)
      default
data_gam.brm <- brm(data_gam.form,
  data = data_gam,
  iter = 5000,
  warmup = 1000,
  chains = 3, cores = 3,
  thin = 5,
  refresh = 0,
  backend = "rstan"
)
Compiling Stan program...
Start sampling
Warning: There were 43 divergent transitions after warmup. See
https://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
to find out why this is a problem and how to eliminate them.
Warning: Examine the pairs() plot to diagnose sampling problems
Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
Running the chains for more iterations may help. See
https://mc-stan.org/misc/warnings.html#tail-ess
prior_summary(data_gam.brm)
                prior     class        coef group resp dpar nlpar lb ub       source
               (flat)         b                                              default
               (flat)         b        sx_1                             (vectorized)
 student_t(3, 5, 2.5) Intercept                                              default
 student_t(3, 0, 2.5)       sds                                    0         default
 student_t(3, 0, 2.5)       sds s(x, k = 3)                        0    (vectorized)
 student_t(3, 0, 2.5)     sigma                                    0         default

sds - standard deviation of the wiggly basis function weights
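To get a feel for what the default student_t(3, 0, 2.5) prior on sds implies, we can sample from the corresponding half-Student-t in base R (a quick sketch, not part of the model fitting):

```r
## Draws from a half-Student-t(df = 3, scale = 2.5), the default sds prior
set.seed(1)
sds_draws <- abs(2.5 * rt(10000, df = 3))
## Implied quantiles: most mass on modest wiggliness, with a long upper tail
quantile(sds_draws, c(0.5, 0.9, 0.99))
```

The heavy tail is what makes the prior only weakly informative: small amounts of wiggliness are favoured, but large values are not ruled out.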

data_gam.brm <- brm(data_gam.form,
  data = data_gam,
  prior = prior(normal(0, 2.5), class = "b"),
  sample_prior = "only",
  iter = 5000,
  warmup = 1000,
  chains = 3,
  thin = 5,
  backend = "rstan",
  refresh = 0
)
Compiling Stan program...
Start sampling

conditional_effects

data_gam.brm |>
  conditional_effects() |>
  plot(points = TRUE)

The following link provides some guidance about defining priors: https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations

When defining our own priors, we typically do not want them to be scaled.

If we wanted to define our own priors that were less vague, yet still not likely to bias the outcomes, we could try the following priors (which I have mainly plucked out of thin air):

  • \(\beta_0\): normal centred at 5 with a standard deviation of 1.5
    • mean of 5: since median(data_gam$y)
    • sd of 1.5: since mad(data_gam$y)
  • \(\beta_1\): normal centred at 0 with a standard deviation of 1.5
    • sd of 1.5: since mad(data_gam$y)
  • \(\sigma\): half-t centred at 0 with a standard deviation of 1.5 OR
    • sd of 1.5: since mad(data_gam$y)
  • \(\sigma\): gamma with shape and rate parameters of 2 and 1
  • sds: half-t centred at 0 with a standard deviation of 10

Sample prior only

I will also overlay the raw data for comparison.

data_gam |> summarise(median(y), mad(y))
# A tibble: 1 × 2
  `median(y)` `mad(y)`
        <dbl>    <dbl>
1           5     1.48
priors <- prior(normal(5, 1.5), class = "Intercept") +
  prior(normal(0, 1.5), class = "b") +
  prior(student_t(3, 0, 1.5), class = "sigma") +
  prior(student_t(3, 0, 10), class = "sds")

data_gam.form <- bf(y ~ s(x, k = 3))
data_gam.brm2 <- brm(data_gam.form,
  data = data_gam,
  prior = priors,
  sample_prior = "only",
  iter = 5000,
  warmup = 1000,
  chains = 3, cores = 3,
  thin = 5,
  backend = "rstan",
  control = list(adapt_delta = 0.99),
  refresh = 0
)
Compiling Stan program...
Start sampling
data_gam.brm2 |>
  conditional_effects() |>
  plot(points = TRUE)

data_gam.brm2 |>
  conditional_smooths() |>
  plot()

Sample prior and posterior

data_gam.brm3 <- update(data_gam.brm2, sample_prior = "yes", cores = 3, refresh = 0)
The desired updates require recompiling the model
Compiling Stan program...
Start sampling
Warning: There were 1 divergent transitions after warmup. See
https://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
to find out why this is a problem and how to eliminate them.
Warning: Examine the pairs() plot to diagnose sampling problems
data_gam.brm3 |>
  conditional_effects() |>
  plot(points = TRUE)

data_gam.brm3 |>
  conditional_effects(spaghetti = TRUE, ndraws = 200) |>
  plot(points = TRUE)

data_gam.brm3 |> get_variables()
 [1] "b_Intercept"     "bs_sx_1"         "sds_sx_1"        "sigma"          
 [5] "Intercept"       "s_sx_1[1]"       "prior_Intercept" "prior_bs"       
 [9] "prior_sds_sx"    "prior_sigma"     "lprior"          "lp__"           
[13] "accept_stat__"   "stepsize__"      "treedepth__"     "n_leapfrog__"   
[17] "divergent__"     "energy__"       
data_gam.brm3 |>
  hypothesis("bs_sx_1 = 0", class = "") |>
  plot()

data_gam.brm3 |>
  hypothesis("sds_sx_1 = 0", class = "") |>
  plot()

data_gam.brm3 |>
  hypothesis("sigma = 0", class = "") |>
  plot()

data_gam.brm3 |> SUYR_prior_and_posterior()
Error in model.matrix.default(f, dat): model frame and formula mismatch in model.matrix()

6 MCMC sampling diagnostics

data_gam.brm3$fit |> stan_trace()
'pars' not specified. Showing first 10 parameters by default.

data_gam.brm3$fit |> stan_ac()
'pars' not specified. Showing first 10 parameters by default.

data_gam.brm3$fit |> stan_rhat()
`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

data_gam.brm3$fit |> stan_ess()
`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

7 Model validation

data_gam.brm3 |> pp_check(type = "dens_overlay", ndraws = 100)

data_gam.resids <- make_brms_dharma_res(data_gam.brm3, integerResponse = FALSE)
wrap_elements(~ testUniformity(data_gam.resids)) +
  wrap_elements(~ plotResiduals(data_gam.resids, form = factor(rep(1, nrow(data_gam))))) +
  wrap_elements(~ plotResiduals(data_gam.resids, quantreg = FALSE)) +
  wrap_elements(~ testDispersion(data_gam.resids))
Warning in smooth.spline(pred, res, df = 10): not using invalid df; must have 1
< df <= n := #{unique x} = 5

testDispersion(data_gam.resids)


    DHARMa nonparametric dispersion test via sd of residuals fitted vs.
    simulated

data:  simulationOutput
dispersion = 0.11431, p-value = 0.25
alternative hypothesis: two.sided

8 Partial effects plots

data_gam.brm3 |>
  conditional_effects() |>
  plot(points = TRUE)

data_gam.brm3 |>
  conditional_effects(spaghetti = TRUE, ndraws = 250) |>
  plot(points = TRUE)

9 Model investigation

data_gam.brm3 |> summary()
Warning: There were 1 divergent transitions after warmup. Increasing
adapt_delta above 0.99 may help. See
http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
 Family: gaussian 
  Links: mu = identity; sigma = identity 
Formula: y ~ s(x, k = 3) 
   Data: data_gam (Number of observations: 5) 
  Draws: 3 chains, each with iter = 5000; warmup = 1000; thin = 5;
         total post-warmup draws = 2400

Smoothing Spline Hyperparameters:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sds(sx_1)    12.84      9.57     1.09    37.47 1.00     1670     1303

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept     5.01      0.49     3.97     6.06 1.00     2292     2049
sx_1          0.30      0.48    -0.75     1.27 1.00     2486     2179

Further Distributional Parameters:
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma     1.10      0.62     0.39     2.66 1.00     1366     2137

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
  • sds(sx_1) is the sd of the smooth weights (spline coefficients). This determines the amount of ‘wiggliness’, in an analogous way to how the sd of group-level effects in a varying slopes and intercepts model determines the amount of variability among groups in slopes and intercepts. However, the actual numeric value of the sds() is not very practically interpretable, because thinking about the variance of smooth weights for any given data and model seems abstract to me. However, if the value is around zero, then this is like ‘complete pooling’ of the basis functions, which means that there isn’t much added value of more than a single basis function.

  • sx_1 is the unpenalized weight (i.e. coefficient) for one of the “natural” parameterized basis functions. The rest of the basis functions are like varying effects. Again, because the actual numeric value of sx_1 is the value for the unpenalized coefficient for one of the basis functions, this wouldn’t seem to have a lot of practically interpretable meaning just from viewing this number.
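The ‘complete pooling’ intuition can be mimicked outside brms with a frequentist mgcv fit (an analogy only, not the Bayesian model above): forcing an enormous smoothing penalty, which is analogous to sds being near zero, collapses the smooth towards its penalty-free null space, i.e. essentially a straight line.

```r
library(mgcv)

d <- data.frame(x = c(2, 4, 8, 10, 14), y = c(3, 5, 6, 7, 4))
## sp fixes the smoothing parameter rather than estimating it
m_wiggly <- gam(y ~ s(x, k = 3), data = d, sp = 0)    # no penalty: maximum wiggliness
m_pooled <- gam(y ~ s(x, k = 3), data = d, sp = 1e6)  # huge penalty: ~ linear fit
## Effective degrees of freedom shrink as the penalty grows
c(wiggly = sum(m_wiggly$edf), pooled = sum(m_pooled$edf))
```

With the huge penalty the total effective degrees of freedom drop towards 2 (intercept plus the unpenalised linear term), which is the frequentist analogue of sds shrinking the wiggly component away.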

data_gam.brm3 |> get_variables()
 [1] "b_Intercept"     "bs_sx_1"         "sds_sx_1"        "sigma"          
 [5] "Intercept"       "s_sx_1[1]"       "prior_Intercept" "prior_bs"       
 [9] "prior_sds_sx"    "prior_sigma"     "lprior"          "lp__"           
[13] "accept_stat__"   "stepsize__"      "treedepth__"     "n_leapfrog__"   
[17] "divergent__"     "energy__"       
data_gam.brm3 |>
  as_draws_df() |>
  dplyr::select(matches("^b_.*|^bs.*|^sds.*|^sigma$|^s_s.*")) |>
  summarise_draws(median,
    HDInterval::hdi,
    Pl = ~ mean(.x < 0),
    Pg = ~ mean(.x > 0)
  )
Warning: Dropping 'draws_df' class as required metadata was removed.
# A tibble: 5 × 6
  variable    median    lower upper     Pl    Pg
  <chr>        <dbl>    <dbl> <dbl>  <dbl> <dbl>
1 b_Intercept  5.00   3.92     5.96 0      1    
2 bs_sx_1      0.321 -0.819    1.18 0.211  0.789
3 sds_sx_1    10.9    0.00208 30.0  0      1    
4 sigma        0.932  0.301    2.31 0      1    
5 s_sx_1[1]   10.9   -2.07    19.1  0.0633 0.937

10 Further analyses

newdata <- with(data_gam, data.frame(x = c(min(x), 9)))
add_epred_draws(
  object = data_gam.brm3, newdata = newdata,
  ndraws = 2400
) |>
  ungroup() |>
  group_by(.draw) |>
  summarise(Diff = diff(.epred)) |>
  summarise(median_hdci(Diff),
    Pl = mean(Diff < 0),
    Pg = mean(Diff > 0)
  )
# A tibble: 1 × 8
      y   ymin  ymax .width .point .interval     Pl    Pg
  <dbl>  <dbl> <dbl>  <dbl> <chr>  <chr>      <dbl> <dbl>
1  3.02 -0.661  5.20   0.95 median hdci      0.0525 0.948
newdata <- with(data_gam, data.frame(x = seq(min(x), max(x), length = 1000)))
data_gam.peak <-
  add_epred_draws(object = data_gam.brm3, newdata = newdata, ndraws = 1000) |>
  ungroup() |>
  group_by(.draw) |>
  # summarise(x = x[which.max(.epred)]) |>
  mutate(diff = .epred - lag(.epred)) |>
  summarise(x = x[which.min(abs(diff))]) |>
  median_hdci(x, .width = 0.95)
data_gam.peak
# A tibble: 1 × 6
      x .lower .upper .width .point .interval
  <dbl>  <dbl>  <dbl>  <dbl> <chr>  <chr>    
1  8.52   4.02     14   0.95 median hdci     
## let's plot this
data_gam.preds <-
  data_gam.brm3 |>
  add_epred_draws(newdata = newdata, object = _) |>
  ungroup() |>
  dplyr::select(-.row, -.chain, -.iteration) |>
  group_by(x) |>
  summarise_draws(median, HDInterval::hdi) |>
  ungroup() |>
  mutate(
    Flag = between(x, data_gam.peak$.lower, data_gam.peak$.upper),
    Grp = data.table::rleid(Flag)
  )
data_gam.preds |> head()
# A tibble: 6 × 7
      x variable median lower upper Flag    Grp
  <dbl> <chr>     <dbl> <dbl> <dbl> <lgl> <int>
1  2    .epred     3.36  1.66  5.73 FALSE     1
2  2.01 .epred     3.37  1.67  5.73 FALSE     1
3  2.02 .epred     3.37  1.68  5.73 FALSE     1
4  2.04 .epred     3.38  1.68  5.72 FALSE     1
5  2.05 .epred     3.39  1.69  5.72 FALSE     1
6  2.06 .epred     3.40  1.70  5.72 FALSE     1
ggplot(data_gam.preds, aes(y = median, x = x)) +
  geom_line(aes(colour = Flag, group = Grp)) +
  geom_ribbon(aes(ymin = lower, ymax = upper, fill = Flag, group = Grp), alpha = 0.2)

Unfortunately, it does not appear that this option provides confidence intervals.

data_gam.brm3 |>
  estimate_relation(keep_iterations = TRUE, length = 1000) |>
  estimate_smooth(x = "x")
Start |   End | Length | Change | Slope |   R2
----------------------------------------------
2.00  |  8.00 |   0.34 |   2.71 |  0.45 | 0.12
8.00  | 14.00 |   0.46 |  -1.76 | -0.29 | 0.12
newdata <- with(data_gam, data.frame(x = seq(min(x), max(x), length = 1000)))
data_gam.brm3 |>
  add_epred_draws(newdata = newdata, object = _) |>
  ungroup() |>
  group_by(.draw) |>
  mutate(diff = .epred - lag(.epred)) |>
  summarise(
    maxGrad = max(abs(diff), na.rm = TRUE),
    x = x[which.max(diff)]
  ) |>
  summarise_draws(median, HDInterval::hdi)
# A tibble: 2 × 4
  variable  median    lower   upper
  <chr>      <dbl>    <dbl>   <dbl>
1 maxGrad  0.00884 0.000737  0.0144
2 x        2.01    2.01     14     
newdata <- with(data_gam, data.frame(x = seq(min(x), max(x), length = 1000)))
data_gam.brm3 |>
  add_epred_draws(newdata = newdata, object = _) |>
  filter(x > 3, x < 13) |>
  ungroup() |>
  group_by(.draw) |>
  mutate(
    diff = .epred - lag(.epred),
    diff2 = diff - lag(diff)
  ) |>
  summarise(
    maxChange = max(abs(diff2), na.rm = TRUE),
    x = x[which.max(diff)]
  ) |>
  summarise_draws(median, HDInterval::hdi)
# A tibble: 2 × 4
  variable     median         lower      upper
  <chr>         <dbl>         <dbl>      <dbl>
1 maxChange 0.0000260 0.00000000545  0.0000430
2 x         3.02      3.02          13.0